{
 "cells": [
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2025-02-07T15:28:16.323926Z",
     "start_time": "2025-02-07T15:28:05.507488Z"
    }
   },
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import math\n",
    "from torch.nn import functional as F\n",
    "# 定义相对位置编码模块\n",
    "class RelativePositionEmbedding(nn.Module):\n",
    "    def __init__(self, max_seq_len: int, embedding_dim: int):\n",
    "        super().__init__()\n",
    "        self.max_seq_len = max_seq_len\n",
    "        self.embedding_dim = embedding_dim\n",
    "        self.relative_positions = nn.Parameter(torch.randn(max_seq_len * 2 - 1, embedding_dim))\n",
    "    \n",
    "    def forward(self, seq_len: int):\n",
    "        relative_positions_matrix = self._generate_relative_positions_matrix(seq_len)\n",
    "        relative_embeddings = F.embedding(relative_positions_matrix, self.relative_positions)\n",
    "        return relative_embeddings\n",
    "    \n",
    "    def _generate_relative_positions_matrix(self, seq_len: int):\n",
    "        range_vec = torch.arange(seq_len)\n",
    "        range_matrix = range_vec.unsqueeze(0).expand(seq_len, seq_len)\n",
    "        distance_matrix = range_matrix - range_matrix.t()\n",
    "        distance_matrix = distance_matrix + self.max_seq_len - 1\n",
    "        return distance_matrix\n",
    "# 定义多头注意力模块\n",
    "class MultiHeadAttention(nn.Module):\n",
    "    def __init__(self, embed_dim: int, num_heads: int, max_seq_len: int):\n",
    "        super().__init__()\n",
    "        self.embed_dim = embed_dim\n",
    "        self.num_heads = num_heads\n",
    "        self.head_dim = embed_dim // num_heads\n",
    "        self.q_proj = nn.Linear(embed_dim, embed_dim)\n",
    "        self.k_proj = nn.Linear(embed_dim, embed_dim)\n",
    "        self.v_proj = nn.Linear(embed_dim, embed_dim)\n",
    "        self.rel_pos_embeddings = RelativePositionEmbedding(max_seq_len, self.head_dim)\n",
    "        self.out_proj = nn.Linear(embed_dim, embed_dim)\n",
    "    \n",
    "    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):\n",
    "        batch_size, seq_len, embed_dim = query.size()\n",
    "        query = self.q_proj(query).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)\n",
    "        key = self.k_proj(key).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)\n",
    "        value = self.v_proj(value).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)\n",
    "        rel_pos_embeddings = self.rel_pos_embeddings(seq_len)\n",
    "        rel_pos_embeddings = rel_pos_embeddings.unsqueeze(0).expand(batch_size, seq_len, seq_len, self.head_dim)\n",
    "        query = query.unsqueeze(3)\n",
    "        key = key.unsqueeze(2)\n",
    "        attention_scores = torch.matmul(query, key.transpose(-1, -2))\n",
    "        attention_scores += torch.matmul(query, rel_pos_embeddings.transpose(-1, -2))\n",
    "        attention_scores = attention_scores / math.sqrt(self.head_dim)\n",
    "        attention_weights = F.softmax(attention_scores, dim=-1)\n",
    "        output = torch.matmul(attention_weights, value)\n",
    "        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)\n",
    "        output = self.out_proj(output)\n",
    "        return output\n",
    "# 示例使用\n",
    "max_seq_len = 10\n",
    "embed_dim = 64\n",
    "num_heads = 4\n",
    "batch_size = 2\n",
    "seq_len = 5\n",
    "\n",
    "query = torch.randn(batch_size, seq_len, embed_dim)\n",
    "key = torch.randn(batch_size, seq_len, embed_dim)\n",
    "value = torch.randn(batch_size, seq_len, embed_dim)\n",
    "\n",
    "attention = MultiHeadAttention(embed_dim, num_heads, max_seq_len)\n",
    "print(attention.rel_pos_embeddings._generate_relative_positions_matrix(5))\n",
    "print(attention.rel_pos_embeddings.relative_positions.shape)\n",
    "print(attention.rel_pos_embeddings(5).shape)\n",
    "# output = attention(query, key, value)\n",
    "# print(output.shape)  # 输出形状为 (batch_size, seq_len, embed_dim)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 9, 10, 11, 12, 13],\n",
      "        [ 8,  9, 10, 11, 12],\n",
      "        [ 7,  8,  9, 10, 11],\n",
      "        [ 6,  7,  8,  9, 10],\n",
      "        [ 5,  6,  7,  8,  9]])\n",
      "torch.Size([19, 16])\n",
      "torch.Size([5, 5, 16])\n"
     ]
    }
   ],
   "execution_count": 1
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
